{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchvision import transforms\n",
    "from torchvision import datasets\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    " \n",
    "# prepare dataset\n",
    " \n",
    "batch_size = 64\n",
    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n",
    " \n",
    "train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)\n",
    "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)\n",
    "test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)\n",
    "test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)\n",
    " \n",
    "# design model using class\n",
    " \n",
    " \n",
    "class Net(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)\n",
    "        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)\n",
    "        self.pooling = torch.nn.MaxPool2d(2)\n",
    "        self.fc = torch.nn.Linear(320, 10)\n",
    " \n",
    " \n",
    "    def forward(self, x):\n",
    "        # flatten data from (n,1,28,28) to (n, 784)\n",
    "        \n",
    "        batch_size = x.size(0)\n",
    "        x = F.relu(self.pooling(self.conv1(x)))\n",
    "        x = F.relu(self.pooling(self.conv2(x)))\n",
    "        x = x.view(batch_size, -1) # -1 此处自动算出的是320\n",
    "        # print(\"x.shape\",x.shape)\n",
    "        x = self.fc(x)\n",
    " \n",
    "        return x\n",
    " \n",
    " \n",
    "model = Net()\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    " \n",
    "# construct loss and optimizer\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)\n",
    " \n",
    "# training cycle forward, backward, update\n",
    " \n",
    " \n",
    "def train(epoch):\n",
    "    running_loss = 0.0\n",
    "    for batch_idx, data in enumerate(train_loader, 0):\n",
    "        inputs, target = data\n",
    "        inputs, target = inputs.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    " \n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    " \n",
    "        running_loss += loss.item()\n",
    "        if batch_idx % 300 == 299:\n",
    "            print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))\n",
    "            running_loss = 0.0\n",
    " \n",
    " \n",
    "def test():\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for data in test_loader:\n",
    "            images, labels = data\n",
    "            images, labels = images.to(device), labels.to(device)\n",
    "            outputs = model(images)\n",
    "            _, predicted = torch.max(outputs.data, dim=1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "    print('accuracy on test set: %d %% ' % (100*correct/total))\n",
    "    return correct/total\n",
    " \n",
    " \n",
    "if __name__ == '__main__':\n",
    "    epoch_list = []\n",
    "    acc_list = []\n",
    "    \n",
    "    for epoch in range(10):\n",
    "        train(epoch)\n",
    "        acc = test()\n",
    "        epoch_list.append(epoch)\n",
    "        acc_list.append(acc)\n",
    "    \n",
    "    plt.plot(epoch_list,acc_list)\n",
    "    plt.ylabel('accuracy')\n",
    "    plt.xlabel('epoch')\n",
    "    plt.show()\n",
    " \n",
    "    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
